{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Attention Maps\n", "\n", "Attention maps are visual representations that highlight the areas of an input image that a neural network focuses on when making predictions. They are particularly useful for understanding and interpreting the decision-making process of deep learning models in both classification and segmentation tasks.\n", "\n", "In classification or segmentation tasks, attention maps reveal which parts of an image are most influential in determining the predicted class label.\n", "\n", "In this notebook we will compute attention maps for a trained UNET model and see how they can be used to interpret the decision-making process of a neural network.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Imports\n", "\n", "import monai\n", "import torch\n", "from torch.utils.data import DataLoader\n", "from monai.transforms import (\n", "    EnsureChannelFirstd,\n", "    AsDiscreted,\n", "    Compose,\n", "    LoadImaged,\n", "    Orientationd,\n", "    Randomizable,\n", "    Resized,\n", "    ScaleIntensityd,\n", "    Spacingd,\n", "    EnsureTyped,\n", "    Lambda\n", ")\n", "import os\n", "import tempfile\n", "from utils.decathlon_dataset import get_decathlon_dataloader\n", "from utils.unet import UNET\n", "from utils.train import train\n", "\n", "import matplotlib.pyplot as plt\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the data\n", "root_dir = './utils/datasets'\n", "task = \"Task04_Hippocampus\"\n", "\n", "train_loader = get_decathlon_dataloader(root_dir, task, \"training\", batch_size=4, num_workers=2, shuffle=True)\n", "val_loader = get_decathlon_dataloader(root_dir, task, \"validation\", batch_size=4, num_workers=2)\n", "\n", "# Load the model\n", "\n", "model = UNET(in_channels=1, out_channels=3)\n", "\n", "# 
Load the model weights\n", "\n", "model.load_state_dict(torch.load(\"trained_unet.pth\"))\n", "\n", "# Set the model to evaluation mode\n", "\n", "model.eval()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Understanding Attention Maps and GradCAM\n", "\n", "GradCAM (Gradient-weighted Class Activation Mapping) is a technique for visualizing the regions of an input image that a convolutional neural network focuses on when making predictions. Let's break down how it works:\n", "\n", "1. **Target Layer Selection**: We choose a target layer in the network, typically one of the later convolutional layers that captures high-level features.\n", "\n", "2. **Forward Pass**: The input image is passed through the network to obtain predictions.\n", "\n", "3. **Backward Pass**: We backpropagate the score of the target class, computing its gradients with respect to the target layer's activations.\n", "\n", "4. **Importance Weighting**: The gradients are globally average pooled over the spatial dimensions to obtain one weight per feature map in the target layer.\n", "\n", "5. **Feature Map Weighting**: These weights are used to scale the corresponding feature maps, emphasizing the features that matter for the target class.\n", "\n", "6. **Heatmap Creation**: The weighted feature maps are summed into a single heatmap, which is passed through a ReLU (keeping only positive evidence for the class) and then normalized.\n", "\n", "By viewing these visualizations, we can gain insight into the model's decision-making process and identify potential biases or unexpected behaviors in our neural networks. This is particularly useful for understanding a model's behavior on medical images.\n", "\n", "Let's use the model we trained in the previous notebook to make predictions on the validation set and compute the attention maps. 
We will use the `GradCAM` class to compute the attention maps; read through the `utils/attention_maps.py` file to understand how it works.\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from utils.attention_maps import GradCAM, apply_cam\n", "# Use the model to generate attention maps\n", "target_layer = model.downs[-1].conv[-3] # Using the last ReLU in the last downsampling block\n", "\n", "grad_cam = GradCAM(model, target_layer)\n", "\n", "# Find a sample image with labels\n", "for sample_data in val_loader:\n", "    # Slicing with 20:21 selects index 20 and keeps the leading dimension\n", "    sample_label = sample_data['label'][20:21]\n", "\n", "    if sample_label.sum() > 10: # Keep only a sample with enough foreground (label sum above 10)\n", "        sample_image = sample_data['image'][20:21] # Same slice, so image and label stay aligned\n", "        break\n", "\n", "# Generate prediction\n", "with torch.no_grad():\n", "    sample_prediction = model(sample_image).argmax(dim=1)\n", "\n", "# Generate CAM for each class\n", "for class_idx in range(3): # Assuming 3 classes (background, anterior, posterior)\n", "    cam = grad_cam.generate_cam(sample_image, class_idx)\n", "    print(f\"Attention map for class {class_idx}\")\n", "    apply_cam(sample_image[0], sample_label, cam, class_idx, sample_prediction)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Question\n", "\n", "What do you observe in the attention maps for this example?\n", "\n", "\n", "Try plotting the attention maps for other samples in the validation set. If possible, find a sample on which the model made a suboptimal or wrong prediction, and check whether the attention maps highlight the regions the model focused on when making that prediction."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Your code here" ] } ], "metadata": { "kernelspec": { "display_name": "DL_labs_GPU", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }